# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert widerface data to mindrecord format."""
import os
import copy
import json
import numpy as np
from mindspore.mindrecord import FileWriter
from mindvision.engine.class_factory import ClassFactory, ModuleType

@ClassFactory.register(ModuleType.DATASET)
class WiderFace2MindRecord:
    """
        Convert a WiderFace label file and its images to MindRecord files.

        Image paths are derived from ``label_path`` by replacing the
        ``label.txt`` component with ``images/`` and appending the relative
        path found on each ``# <relative path>`` header line.

        Args:
            label_path (str): Path of the annotation file (``.../label.txt``).
            mindrecord_dir (str): Directory where the mindrecord files are written.
            eval_imageid_file (str): File name (relative to the current working
                directory) of the JSON file mapping image path -> image id that
                is written when the eval dataset is read. If None, no JSON file
                is written.
            is_training (bool): If True, ``__call__`` builds the train
                mindrecord, otherwise the eval one. Default: True.
    """
    def __init__(self, label_path, mindrecord_dir=None, eval_imageid_file=None, is_training=True):
        self.label_path = label_path
        self.mindrecord_dir = mindrecord_dir
        self.images_list = []
        self.labels_list = []
        self.is_training = is_training
        self.eval_imageid_file = eval_imageid_file

    @staticmethod
    def _image_path(label_path, relative_path):
        """Build the image path for a header entry and check that it exists."""
        path = label_path.replace('label.txt', 'images/') + relative_path
        # Raised explicitly (instead of using an `assert` statement) so the check
        # is not stripped under `python -O`; the exception type is kept.
        if not os.path.exists(path):
            raise AssertionError('image path is not exists.')
        return path

    def creat_widerface_train_label(self):
        """
        Get train image paths and annotations from the Widerface label file.

        The label file consists of ``# <relative image path>`` header lines, each
        followed by one line of space separated numbers per face. Boxes whose
        width (index 2) or height (index 3) is not positive are dropped, and
        images left without any box are dropped as well.

        ``self.images_list`` and ``self.labels_list`` are rebuilt on every call,
        so calling this method repeatedly does not duplicate entries.

        Returns:
            tuple, (images_list, labels_list) where ``labels_list[i]`` is the list
            of float label rows of ``images_list[i]``.

        Raises:
            AssertionError: If a referenced image does not exist.
        """
        image_paths = []
        grouped_labels = []
        current = []
        seen_header = False
        with open(self.label_path, 'r') as label_file:
            for raw_line in label_file:
                line = raw_line.rstrip()
                if not line:
                    # tolerate blank lines (e.g. a trailing newline at end of file)
                    continue
                if line.startswith('#'):
                    if seen_header:
                        # close the label group of the previous image
                        grouped_labels.append(current)
                        current = []
                    seen_header = True
                    # line[2:] removes the leading '# '
                    image_paths.append(self._image_path(self.label_path, line[2:]))
                else:
                    current.append([float(value) for value in line.split()])
        # add the labels of the last image; nothing to add for a file without headers
        if seen_header:
            grouped_labels.append(current)

        self.images_list = []
        self.labels_list = []
        for path, boxes in zip(image_paths, grouped_labels):
            # del bbox which width is zero or height is zero
            valid = [box for box in boxes if not (box[2] <= 0 or box[3] <= 0)]
            if valid:
                self.images_list.append(path)
                self.labels_list.append(valid)
        return self.images_list, self.labels_list

    def read_train_dataset(self, dataset):
        """
        Convert raw Widerface label rows to 15-column float32 annotations.

        Args:
            dataset (tuple): (image paths, per-image label rows), as returned by
                ``creat_widerface_train_label``.

        Returns:
            tuple, (image_files, image_anno_dict). Each annotation array has shape
            (num_boxes, 15): columns 0-3 hold x1, y1, x2, y2, columns 4-13 the
            five landmark (x, y) pairs and column 14 a flag that is -1 when the
            first landmark x is negative and 1 otherwise.
        """
        image_files = []
        image_anno_dict = {}
        for img, labels in zip(dataset[0], dataset[1]):
            if not labels:
                continue
            # allocate once per image instead of growing the array with np.append
            anns = np.zeros((len(labels), 15), dtype=np.float32)
            for row, label in zip(anns, labels):
                # get bbox: x1, y1 are stored as is, x2 = x1 + w, y2 = y1 + h
                row[0:2] = label[0:2]
                row[2] = label[0] + label[2]
                row[3] = label[1] + label[3]
                # get landmarks: keep the (x, y) pair of each point and skip the
                # value stored after every pair
                row[4:14] = label[4:6] + label[7:9] + label[10:12] + label[13:15] + label[16:18]
                # set flag
                row[14] = -1 if label[4] < 0 else 1
            image_files.append(img)
            image_anno_dict[img] = anns
        return image_files, image_anno_dict

    def train_data_to_mindrecord_byte_image(self, prefix="widerface.mindrecord", file_num=8):
        """
        Create Train MindRecord file.

        Args:
            prefix (str): File name prefix of the mindrecord files inside
                ``self.mindrecord_dir``.
            file_num (int): Number of mindrecord files passed to FileWriter.
        """
        dataset = self.creat_widerface_train_label()
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        image_files, image_anno_dict = self.read_train_dataset(dataset)
        widerface = {
            "image": {"type": "bytes"},
            "annotation": {"type": "float32", "shape": [-1, 15]},
        }
        writer.add_schema(widerface, "widerface")
        for image_name in image_files:
            with open(image_name, 'rb') as f:
                img = f.read()
            annos = np.array(image_anno_dict[image_name], dtype=np.float32)
            row = {"image": img, "annotation": annos}
            writer.write_raw_data([row])
        writer.commit()

    def read_eval_dataset(self, label_path):
        """
        Get test image paths and image ids from the Widerface label file.

        Only the ``# <relative image path>`` header lines are used; every image
        gets a consecutive id starting at 0. If ``self.eval_imageid_file`` is set
        and that file does not exist yet, the path -> id mapping is written to it
        as JSON.

        Args:
            label_path (str): Path of the annotation file (``.../label.txt``).

        Returns:
            tuple, (image_files, image_label_dict) where ``image_label_dict`` maps
            each image path to a numpy scalar array holding its id.

        Raises:
            AssertionError: If a referenced image does not exist.
        """
        json_prefix = self.eval_imageid_file
        image_ids = {}
        image_files = []
        image_label_dict = {}
        with open(label_path, 'r') as label_file:
            for raw_line in label_file:
                line = raw_line.rstrip()
                if not line.startswith('#'):
                    continue
                # line[2:] clears '#' and the space
                path = self._image_path(label_path, line[2:])
                image_id = len(image_files)
                image_files.append(path)
                image_label_dict[path] = np.array(image_id)
                image_ids[path] = image_id
        # Without an id file name there is nowhere to store the mapping.
        if json_prefix is not None and not os.path.exists('./' + json_prefix):
            # NOTE(review): "not ok" looks like a leftover debug message; it is
            # printed whenever the id file has to be created.
            print("not ok")
            # serialize once, after all images are known
            with open('./' + json_prefix, 'w') as json_file:
                json_file.write(json.dumps(image_ids, indent=2))
        return image_files, image_label_dict

    def test_data_to_mindrecord_byte_image(self, prefix="widerface.mindrecord", file_num=1):
        """
        Create Test MindRecord file.

        Args:
            prefix (str): File name prefix of the mindrecord file inside
                ``self.mindrecord_dir``.
            file_num (int): Number of mindrecord files passed to FileWriter.
        """
        mindrecord_dir = self.mindrecord_dir
        mindrecord_path = os.path.join(mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        image_files, image_anno_dict = self.read_eval_dataset(label_path=self.label_path)
        widerface_eval = {
            "image": {"type": "bytes"},
            "annotation": {"type": "int32", "shape": [-1, 1]}
        }
        writer.add_schema(widerface_eval, "widerface_eval")
        for image_name in image_files:
            with open(image_name, 'rb') as f:
                img = f.read()
            annos = np.array(image_anno_dict[image_name], dtype=np.int32)
            row = {"image": img, "annotation": annos}
            writer.write_raw_data([row])
        writer.commit()

    def __call__(self, prefix='widerface.mindrecord'):
        """
        Write mindrecord file unless it already exists.

        Args:
            prefix (str): File name prefix of the mindrecord file(s) inside
                ``self.mindrecord_dir``.
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        if not self.is_training:
            if os.path.exists(mindrecord_path):
                print("mindrecord test dataset is existed")
                return
            # forward the prefix so the file written matches the one checked above
            self.test_data_to_mindrecord_byte_image(prefix)
        else:
            # the train writer produces several files; the first one is
            # expected at "<prefix>0"
            if os.path.exists(mindrecord_path + "0"):
                print("mindrecord train dataset is existed")
                return
            self.train_data_to_mindrecord_byte_image(prefix)
